{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a Question/Answering system over SQL data\n",
    "\n",
    ":::info Prerequisites\n",
    "\n",
    "This guide assumes familiarity with the following concepts:\n",
    "\n",
    "- [Chat models](/docs/concepts/chat_models)\n",
    "- [Tools](/docs/concepts/tools)\n",
    "- [Agents](/docs/concepts/agents)\n",
    "- [LangGraph](/docs/concepts/architecture/#langgraph)\n",
    "\n",
    ":::\n",
    "\n",
    "Enabling an LLM system to query structured data can be qualitatively different from querying unstructured text data. With the latter, it is common to generate text that can be searched against a vector database. With structured data, the LLM instead often writes and executes queries in a domain-specific language (DSL) such as SQL. In this guide we'll go over the basic ways to create a Q&A system over tabular data in databases. We will cover implementations using both [chains](/docs/tutorials/sql_qa#chains) and [agents](/docs/tutorials/sql_qa#agents). These systems will allow us to ask a question about the data in a database and get back a natural language answer. The main difference between the two is that the agent can query the database in a loop, as many times as it needs to answer the question.\n",
    "\n",
    "## ⚠️ Security note ⚠️\n",
    "\n",
    "Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
    "\n",
    "\n",
    "## Architecture\n",
    "\n",
    "At a high level, the steps of these systems are:\n",
    "\n",
    "1. **Convert question to SQL query**: Model converts user input to a SQL query.\n",
    "2. **Execute SQL query**: Execute the query.\n",
    "3. **Answer the question**: Model responds to user input using the query results.\n",
    "\n",
    "![sql_usecase.png](../../static/img/sql_usecase.png)\n",
    "\n",
    "## Setup\n",
    "\n",
    "First, install the required packages and set environment variables:\n",
    "```bash npm2yarn\n",
    "npm i langchain @langchain/community @langchain/langgraph typeorm sqlite3\n",
    "```\n",
    "\n",
    "```shell\n",
    "# Uncomment the below to use LangSmith. Not required, but recommended for debugging and observability.\n",
    "# export LANGSMITH_API_KEY=<your key>\n",
    "# export LANGSMITH_TRACING=true\n",
    "\n",
    "# Reduce tracing latency if you are not in a serverless environment\n",
    "# export LANGCHAIN_CALLBACKS_BACKGROUND=true\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample data\n",
    "\n",
    "The example below uses a SQLite connection with the Chinook database, a sample database that represents a digital media store. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook. You can also download and build the database from the command line:\n",
    "```bash\n",
    "curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db\n",
    "```\n",
    "\n",
    "Now that `Chinook.db` is in our directory, we can interface with it using the [SqlDatabase](https://api.js.langchain.com/classes/langchain.sql_db.SqlDatabase.html) class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{\"ArtistId\":1,\"Name\":\"AC/DC\"},{\"ArtistId\":2,\"Name\":\"Accept\"},{\"ArtistId\":3,\"Name\":\"Aerosmith\"},{\"ArtistId\":4,\"Name\":\"Alanis Morissette\"},{\"ArtistId\":5,\"Name\":\"Alice In Chains\"},{\"ArtistId\":6,\"Name\":\"Antônio Carlos Jobim\"},{\"ArtistId\":7,\"Name\":\"Apocalyptica\"},{\"ArtistId\":8,\"Name\":\"Audioslave\"},{\"ArtistId\":9,\"Name\":\"BackBeat\"},{\"ArtistId\":10,\"Name\":\"Billy Cobham\"}]\n"
     ]
    }
   ],
   "source": [
    "import { SqlDatabase } from \"langchain/sql_db\";\n",
    "import { DataSource } from \"typeorm\";\n",
    "\n",
    "const datasource = new DataSource({\n",
    "  type: \"sqlite\",\n",
    "  database: \"Chinook.db\",\n",
    "});\n",
    "const db = await SqlDatabase.fromDataSourceParams({\n",
    "  appDataSource: datasource,\n",
    "});\n",
    "\n",
    "await db.run(\"SELECT * FROM Artist LIMIT 10;\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.\n",
    "\n",
    "## Chains {#chains}\n",
    "\n",
    "Chains are compositions of predictable steps. In [LangGraph](/docs/concepts/architecture#langchainlanggraph), we can represent a chain via a simple sequence of nodes. Let's create a sequence of steps that, given a question, does the following:\n",
    "- converts the question into a SQL query;\n",
    "- executes the query;\n",
    "- uses the result to answer the original question.\n",
    "\n",
    "There are scenarios not supported by this arrangement. For example, this system will execute a SQL query for any user input, even \"hello\". Importantly, as we'll see below, some questions require more than one query to answer. We will address these scenarios in the Agents section.\n",
    "\n",
    "### Application state\n",
    "\n",
    "The LangGraph [state](https://langchain-ai.github.io/langgraphjs/concepts/low_level/#state) of our application controls what data is input to the application, transferred between steps, and output by the application.\n",
    "\n",
    "For this application, we can just keep track of the input question, generated query, query result, and generated answer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { Annotation } from \"@langchain/langgraph\";\n",
    "\n",
    "\n",
    "const InputStateAnnotation = Annotation.Root({\n",
    "  question: Annotation<string>,\n",
    "});\n",
    "\n",
    "\n",
    "const StateAnnotation = Annotation.Root({\n",
    "  question: Annotation<string>,\n",
    "  query: Annotation<string>,\n",
    "  result: Annotation<string>,\n",
    "  answer: Annotation<string>,\n",
    "});"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we just need functions that operate on this state and populate its contents.\n",
    "\n",
    "### Convert question to SQL query\n",
    "\n",
    "The first step is to take the user input and convert it to a SQL query. To reliably obtain SQL queries (without markdown formatting, explanations, or clarifications), we will make use of LangChain's [structured output](/docs/concepts/structured_outputs/) abstraction.\n",
    "\n",
    "Let's select a chat model for our application:\n",
    "\n",
    "```{=mdx}\n",
    "import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
    "\n",
    "<ChatModelTabs customVarName=\"llm\" />\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "// @lc-docs-hide-cell\n",
    "import { ChatOpenAI } from '@langchain/openai';\n",
    "\n",
    "const llm = new ChatOpenAI({\n",
    "  model: \"gpt-4o-mini\",\n",
    "  temperature: 0,\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will pull a prompt from the [Prompt Hub](https://smith.langchain.com/hub) to instruct the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
      "\n",
      "Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n",
      "\n",
      "Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
      "\n",
      "Only use the following tables:\n",
      "{table_info}\n",
      "Question: {input}\n"
     ]
    }
   ],
   "source": [
    "import { pull } from \"langchain/hub\";\n",
    "import { ChatPromptTemplate } from \"@langchain/core/prompts\";\n",
    "\n",
    "const queryPromptTemplate = await pull<ChatPromptTemplate>(\"langchain-ai/sql-query-system-prompt\");\n",
    "\n",
    "queryPromptTemplate.promptMessages.forEach(message => {\n",
    "    console.log(message.lc_kwargs.prompt.template);\n",
    "});"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prompt includes several parameters we will need to populate, such as the SQL dialect and table schemas. LangChain's [SqlDatabase](https://api.js.langchain.com/classes/langchain.sql_db.SqlDatabase.html) object includes methods to help with this. Our `writeQuery` step will just populate these parameters and prompt a model to generate the SQL query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { z } from \"zod\";\n",
    "\n",
    "\n",
    "const queryOutput = z.object({\n",
    "  query: z.string().describe(\"Syntactically valid SQL query.\"),\n",
    "});\n",
    "\n",
    "const structuredLlm = llm.withStructuredOutput(queryOutput)\n",
    "\n",
    "\n",
    "const writeQuery = async (state: typeof InputStateAnnotation.State) => {\n",
    "  const promptValue = await queryPromptTemplate.invoke({\n",
    "      dialect: db.appDataSourceOptions.type,\n",
    "      top_k: 10,\n",
    "      table_info: await db.getTableInfo(),\n",
    "      input: state.question\n",
    "  })\n",
    "  const result = await structuredLlm.invoke(promptValue)\n",
    "  return { query: result.query }\n",
    "};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test it out:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{ query: \u001b[32m'SELECT COUNT(*) AS EmployeeCount FROM Employee;'\u001b[39m }\n"
     ]
    }
   ],
   "source": [
    "await writeQuery({ question: \"How many Employees are there?\" })"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Execute query\n",
    "\n",
    "**This is the most dangerous part of creating a SQL chain.** Consider carefully whether it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chains before query execution (see below).\n",
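    "\n",
    "As an extra layer of defense, some applications also reject generated statements that are not plain reads before running them. This is never a substitute for narrowly scoped permissions. The following is a hypothetical, naive guard written for illustration only; it is not part of LangChain. Its keyword check can be bypassed (for example, by SQLite functions called inside a `SELECT`) and can also reject valid queries, so treat it as a heuristic:\n",
    "\n",
    "```typescript\n",
    "// Hypothetical helper, not part of LangChain: a naive read-only check.\n",
    "// Keyword matching can be fooled and may reject valid queries; it is a heuristic.\n",
    "const FORBIDDEN_KEYWORDS =\n",
    "  /\\b(insert|update|delete|drop|alter|create|truncate|attach|detach|pragma|vacuum)\\b/i;\n",
    "\n",
    "function assertReadOnly(sql: string): string {\n",
    "  // Drop surrounding whitespace and a single trailing semicolon.\n",
    "  const statement = sql.trim().replace(/;\\s*$/, \"\");\n",
    "  // Reject anything that looks like more than one statement.\n",
    "  if (statement.includes(\";\")) {\n",
    "    throw new Error(\"Multiple statements are not allowed.\");\n",
    "  }\n",
    "  // Only allow statements that start with SELECT or WITH.\n",
    "  if (!/^(select|with)\\b/i.test(statement)) {\n",
    "    throw new Error(\"Only SELECT queries are allowed.\");\n",
    "  }\n",
    "  // Reject write and DDL keywords anywhere in the statement.\n",
    "  if (FORBIDDEN_KEYWORDS.test(statement)) {\n",
    "    throw new Error(\"Statement contains a forbidden keyword.\");\n",
    "  }\n",
    "  return statement;\n",
    "}\n",
    "\n",
    "console.log(assertReadOnly(\"SELECT COUNT(*) AS EmployeeCount FROM Employee;\"));\n",
    "```\n",
    "\n",
    "In this chain, such a check could be called on `state.query` at the start of the `executeQuery` node, before the tool runs.\n",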
    "\n",
    "To execute the query, we will use the `QuerySqlTool` from `langchain/tools/sql`. Our `executeQuery` node will just wrap this tool:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { QuerySqlTool } from \"langchain/tools/sql\";\n",
    "\n",
    "const executeQuery = async (state: typeof StateAnnotation.State) => {\n",
    "  const executeQueryTool = new QuerySqlTool(db);\n",
    "  return { result: await executeQueryTool.invoke(state.query) }\n",
    "};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing this step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{ result: \u001b[32m'[{\"EmployeeCount\":8}]'\u001b[39m }\n"
     ]
    }
   ],
   "source": [
    "await executeQuery({\n",
    "    question: \"\",\n",
    "    query: \"SELECT COUNT(*) AS EmployeeCount FROM Employee;\",\n",
    "    result: \"\",\n",
    "    answer: \"\"\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate answer\n",
    "\n",
    "Our last step generates an answer to the question using the information pulled from the database:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "const generateAnswer = async (state: typeof StateAnnotation.State) => {\n",
    "  const promptValue = \n",
    "    \"Given the following user question, corresponding SQL query, \" +\n",
    "    \"and SQL result, answer the user question.\\n\\n\" +\n",
    "    `Question: ${state.question}\\n` +\n",
    "    `SQL Query: ${state.query}\\n` +\n",
    "    `SQL Result: ${state.result}\\n`;\n",
    "  const response = await llm.invoke(promptValue)\n",
    "  return { answer: response.content }\n",
    "};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Orchestrating with LangGraph\n",
    "\n",
    "Finally, we compile our application into a single `graph` object. In this case, we are just connecting the three steps into a single sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { StateGraph } from \"@langchain/langgraph\";\n",
    "\n",
    "const graphBuilder = new StateGraph({\n",
    "  stateSchema: StateAnnotation,\n",
    "})\n",
    "  .addNode(\"writeQuery\", writeQuery)\n",
    "  .addNode(\"executeQuery\", executeQuery)\n",
    "  .addNode(\"generateAnswer\", generateAnswer)\n",
    "  .addEdge(\"__start__\", \"writeQuery\")\n",
    "  .addEdge(\"writeQuery\", \"executeQuery\")\n",
    "  .addEdge(\"executeQuery\", \"generateAnswer\")\n",
    "  .addEdge(\"generateAnswer\", \"__end__\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "const graph = graphBuilder.compile()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LangGraph also comes with built-in utilities for visualizing the control flow of your application:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```javascript\n",
    "// Note: tslab only works inside a jupyter notebook. Don't worry about running this code yourself!\n",
    "import * as tslab from \"tslab\";\n",
    "\n",
    "const image = await graph.getGraph().drawMermaidPng();\n",
    "const arrayBuffer = await image.arrayBuffer();\n",
    "\n",
    "await tslab.display.png(new Uint8Array(arrayBuffer));\n",
    "```\n",
    "\n",
    "![graph_img_sql_qa](../../static/img/graph_img_sql_qa.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test our application! Note that we can stream the results of individual steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{ question: \u001b[32m'How many employees are there?'\u001b[39m }\n",
      "\n",
      "====\n",
      "\n",
      "{\n",
      "  writeQuery: { query: \u001b[32m'SELECT COUNT(*) AS EmployeeCount FROM Employee;'\u001b[39m }\n",
      "}\n",
      "\n",
      "====\n",
      "\n",
      "{ executeQuery: { result: \u001b[32m'[{\"EmployeeCount\":8}]'\u001b[39m } }\n",
      "\n",
      "====\n",
      "\n",
      "{ generateAnswer: { answer: \u001b[32m'There are 8 employees.'\u001b[39m } }\n",
      "\n",
      "====\n",
      "\n"
     ]
    }
   ],
   "source": [
    "let inputs = { question: \"How many employees are there?\" }\n",
    "\n",
    "console.log(inputs)\n",
    "console.log(\"\\n====\\n\");\n",
    "for await (\n",
    "  const step of await graph.stream(inputs, {\n",
    "    streamMode: \"updates\",\n",
    "  })\n",
    ") {\n",
    "  console.log(step);\n",
    "  console.log(\"\\n====\\n\");\n",
    "}"
   ]
  },
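  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between stream modes concrete, here is a toy model in plain TypeScript, with no LangGraph involved. It models the default behavior of the `Annotation<string>` fields above, which declare no reducer: each key simply keeps the last value written to it. Each node returns only the keys it changes. In `\"updates\"` mode, each streamed chunk is that partial update, keyed by node name. In `\"values\"` mode (used by the agent later in this guide), each chunk is the full merged state, starting with the input. The node functions are hypothetical stubs, not the real `writeQuery`, `executeQuery`, and `generateAnswer`:\n",
    "\n",
    "```typescript\n",
    "type State = { question: string; query?: string; result?: string; answer?: string };\n",
    "type GraphNode = (state: State) => Partial<State>;\n",
    "\n",
    "// Hypothetical stubs standing in for the three nodes of the chain.\n",
    "const nodes: [string, GraphNode][] = [\n",
    "  [\"writeQuery\", () => ({ query: \"SELECT COUNT(*) AS EmployeeCount FROM Employee;\" })],\n",
    "  [\"executeQuery\", () => ({ result: '[{\"EmployeeCount\":8}]' })],\n",
    "  [\"generateAnswer\", (s) => ({ answer: `The result was ${s.result}.` })],\n",
    "];\n",
    "\n",
    "// Run the nodes in order, merge each partial update into the state\n",
    "// (last write wins), and record what each stream mode would emit.\n",
    "function simulateStream(input: State) {\n",
    "  let state: State = { ...input };\n",
    "  const updates: Record<string, Partial<State>>[] = [];\n",
    "  const values: State[] = [{ ...state }]; // \"values\" also emits the input state\n",
    "  for (const [name, node] of nodes) {\n",
    "    const update = node(state);\n",
    "    state = { ...state, ...update };\n",
    "    updates.push({ [name]: update });\n",
    "    values.push({ ...state });\n",
    "  }\n",
    "  return { updates, values };\n",
    "}\n",
    "\n",
    "const { updates, values } = simulateStream({ question: \"How many employees are there?\" });\n",
    "console.log(updates[1]); // only the key written by executeQuery\n",
    "console.log(values[2]); // question, query, and result together\n",
    "```"
   ]
  },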
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check out the [LangSmith trace](https://smith.langchain.com/public/4cb42037-55cf-4da9-8b3a-8410482dbd32/r)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Human-in-the-loop\n",
    "\n",
    "LangGraph supports a number of features that can be useful for this workflow. One of them is [human-in-the-loop](https://langchain-ai.github.io/langgraphjs/concepts/human_in_the_loop/): we can interrupt our application before sensitive steps (such as the execution of a SQL query) for human review. This is enabled by LangGraph's [persistence](https://langchain-ai.github.io/langgraphjs/concepts/persistence/) layer, which saves run progress to your storage of choice. Below, we use in-memory storage:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { MemorySaver } from \"@langchain/langgraph\";\n",
    "\n",
    "const checkpointer = new MemorySaver();\n",
    "const graphWithInterrupt = graphBuilder.compile({\n",
    "    checkpointer: checkpointer,\n",
    "    interruptBefore: [\"executeQuery\"]\n",
    "});\n",
    "\n",
    "// Now that we're using persistence, we need to specify a thread ID\n",
    "// so that we can continue the run after review.\n",
    "const threadConfig = {\n",
    "    configurable: { thread_id: \"1\" },\n",
    "    streamMode: \"updates\" as const\n",
    "};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```javascript\n",
    "const image = await graphWithInterrupt.getGraph().drawMermaidPng();\n",
    "const arrayBuffer = await image.arrayBuffer();\n",
    "\n",
    "await tslab.display.png(new Uint8Array(arrayBuffer));\n",
    "```\n",
    "\n",
    "![graph_img_sql_qa_interrupt](../../static/img/graph_img_sql_qa_interrupt.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's repeat the same run. This time, the graph pauses before executing the query so that it can be reviewed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{ question: \u001b[32m'How many employees are there?'\u001b[39m }\n",
      "\n",
      "====\n",
      "\n",
      "{\n",
      "  writeQuery: { query: \u001b[32m'SELECT COUNT(*) AS EmployeeCount FROM Employee;'\u001b[39m }\n",
      "}\n",
      "\n",
      "====\n",
      "\n",
      "---GRAPH INTERRUPTED---\n"
     ]
    }
   ],
   "source": [
    "console.log(inputs)\n",
    "console.log(\"\\n====\\n\");\n",
    "for await (\n",
    "  const step of await graphWithInterrupt.stream(inputs, threadConfig)\n",
    ") {\n",
    "  console.log(step);\n",
    "  console.log(\"\\n====\\n\");\n",
    "}\n",
    "\n",
    "// Logged once the stream ends, which happens when the graph pauses before `executeQuery`.\n",
    "console.log(\"---GRAPH INTERRUPTED---\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The run interrupts before the query is executed. At this point, our application can handle an interaction with a user, who reviews the query.\n",
    "\n",
    "If approved, running the same thread with a `null` input will continue from where we left off. This is enabled by LangGraph's [persistence](https://langchain-ai.github.io/langgraphjs/concepts/persistence/) layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{ executeQuery: { result: \u001b[32m'[{\"EmployeeCount\":8}]'\u001b[39m } }\n",
      "\n",
      "====\n",
      "\n",
      "{ generateAnswer: { answer: \u001b[32m'There are 8 employees.'\u001b[39m } }\n",
      "\n",
      "====\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for await (\n",
    "  const step of await graphWithInterrupt.stream(null, threadConfig)\n",
    ") {\n",
    "  console.log(step);\n",
    "  console.log(\"\\n====\\n\");\n",
    "}"
   ]
  },
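  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resume call above continues unconditionally. In practice, you would first show the pending query to a reviewer and resume only if they approve it. The dependency-free sketch below models that yes/no gate around an interrupt. A `Map` keyed by thread ID stands in for the checkpointer, the runner stops before any step named in `interruptBefore`, and a `null` input resumes from the saved position. The step functions and the `approve` policy are hypothetical. This is a toy model of the pattern, not LangGraph's implementation:\n",
    "\n",
    "```typescript\n",
    "type State = { question: string; query?: string; result?: string };\n",
    "type Step = { name: string; run: (state: State) => Partial<State> };\n",
    "\n",
    "// Stand-in for a checkpointer: saved state and next step index per thread ID.\n",
    "const checkpoints = new Map<string, { state: State; next: number }>();\n",
    "\n",
    "// Hypothetical stubs for the first two nodes of the chain.\n",
    "const steps: Step[] = [\n",
    "  { name: \"writeQuery\", run: () => ({ query: \"SELECT COUNT(*) AS EmployeeCount FROM Employee;\" }) },\n",
    "  { name: \"executeQuery\", run: () => ({ result: '[{\"EmployeeCount\":8}]' }) },\n",
    "];\n",
    "\n",
    "function runThread(threadId: string, interruptBefore: string[], input: State | null) {\n",
    "  const saved = checkpoints.get(threadId);\n",
    "  let state: State;\n",
    "  let i = 0;\n",
    "  let resumeAt = -1; // index of the step we resume at; it is not interrupted again\n",
    "  if (input === null) {\n",
    "    // A null input resumes the saved run for this thread.\n",
    "    if (!saved) throw new Error(`No checkpoint for thread ${threadId}`);\n",
    "    state = saved.state;\n",
    "    i = resumeAt = saved.next;\n",
    "  } else {\n",
    "    state = input;\n",
    "  }\n",
    "  for (; i < steps.length; i++) {\n",
    "    if (interruptBefore.includes(steps[i].name) && i !== resumeAt) {\n",
    "      // Save progress and pause before the sensitive step.\n",
    "      checkpoints.set(threadId, { state, next: i });\n",
    "      return { state, interrupted: true };\n",
    "    }\n",
    "    state = { ...state, ...steps[i].run(state) };\n",
    "  }\n",
    "  checkpoints.delete(threadId);\n",
    "  return { state, interrupted: false };\n",
    "}\n",
    "\n",
    "// Hypothetical approval policy; a real application would ask a person.\n",
    "const approve = (query: string) => /^\\s*select\\b/i.test(query);\n",
    "\n",
    "const paused = runThread(\"1\", [\"executeQuery\"], { question: \"How many employees are there?\" });\n",
    "if (paused.interrupted && approve(paused.state.query ?? \"\")) {\n",
    "  const finished = runThread(\"1\", [\"executeQuery\"], null);\n",
    "  console.log(finished.state.result);\n",
    "} else {\n",
    "  console.log(\"Query rejected; it was not executed.\");\n",
    "}\n",
    "```"
   ]
  },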
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the LangGraph [human-in-the-loop guide](https://langchain-ai.github.io/langgraphjs/concepts/human_in_the_loop/) for more detail and examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Next steps\n",
    "\n",
    "For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For these and other advanced techniques, check out:\n",
    "\n",
    "* [Prompting strategies](/docs/how_to/sql_prompting): Advanced prompt engineering techniques.\n",
    "* [Query checking](/docs/how_to/sql_query_checking): Add query validation and error handling.\n",
    "* [Large databases](/docs/how_to/sql_large_db): Techniques for working with large databases."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agents {#agents}\n",
    "\n",
    "[Agents](/docs/concepts/agents) leverage the reasoning capabilities of LLMs to make decisions during execution. Using agents lets you hand the model more discretion over the query generation and execution process. Although their behavior is less predictable than the \"chain\" above, they offer some advantages:\n",
    "\n",
    "- They can query the database as many times as needed to answer the user question.\n",
    "- They can recover from errors by running a generated query, reading the error message it returns, and regenerating the query correctly.\n",
    "- They can answer questions based on the database's schema as well as its content (like describing a specific table).\n",
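    "\n",
    "Under the hood, a ReAct-style agent runs a loop. At each turn, the model either requests a tool call or returns a final answer. Each tool result, including any error message, is appended to the conversation before the model is called again. The dependency-free sketch below makes that loop explicit. `scriptedModel` is a hypothetical stand-in that replays a fixed plan instead of calling an LLM, and the tools only mimic the names of the toolkit's tools. It illustrates the control flow, not how `createReactAgent` is implemented:\n",
    "\n",
    "```typescript\n",
    "type ToolCall = { name: string; input: string };\n",
    "type ModelTurn = { toolCall?: ToolCall; answer?: string };\n",
    "type Message = { role: \"user\" | \"ai\" | \"tool\"; content: string };\n",
    "\n",
    "// Hypothetical tools; the real SqlToolkit tools query the database.\n",
    "const tools: Record<string, (input: string) => string> = {\n",
    "  \"list-tables-sql\": () => \"Album, Artist, Customer, Employee, Invoice\",\n",
    "  \"query-sql\": (sql) =>\n",
    "    sql.includes(\"Employees\") ? \"Error: no such table: Employees\" : '[{\"EmployeeCount\":8}]',\n",
    "};\n",
    "\n",
    "// Hypothetical model: replays a fixed plan that includes one failed query.\n",
    "function scriptedModel(history: Message[]): ModelTurn {\n",
    "  const plan: ModelTurn[] = [\n",
    "    { toolCall: { name: \"list-tables-sql\", input: \"\" } },\n",
    "    { toolCall: { name: \"query-sql\", input: \"SELECT COUNT(*) FROM Employees;\" } },\n",
    "    { toolCall: { name: \"query-sql\", input: \"SELECT COUNT(*) AS EmployeeCount FROM Employee;\" } },\n",
    "    { answer: \"There are 8 employees.\" },\n",
    "  ];\n",
    "  // Pick the next planned turn based on how many tool results came back so far.\n",
    "  const toolResults = history.filter((m) => m.role === \"tool\").length;\n",
    "  return plan[Math.min(toolResults, plan.length - 1)];\n",
    "}\n",
    "\n",
    "function runAgent(question: string, maxSteps = 10) {\n",
    "  const history: Message[] = [{ role: \"user\", content: question }];\n",
    "  for (let step = 0; step < maxSteps; step++) {\n",
    "    const turn = scriptedModel(history);\n",
    "    if (turn.answer !== undefined) {\n",
    "      history.push({ role: \"ai\", content: turn.answer });\n",
    "      return { answer: turn.answer, history };\n",
    "    }\n",
    "    const call = turn.toolCall!;\n",
    "    history.push({ role: \"ai\", content: `${call.name}(${call.input})` });\n",
    "    // Errors come back as ordinary tool output, so the model can read them and retry.\n",
    "    history.push({ role: \"tool\", content: tools[call.name](call.input) });\n",
    "  }\n",
    "  throw new Error(\"Agent did not finish within maxSteps\");\n",
    "}\n",
    "\n",
    "const run = runAgent(\"How many employees are there?\");\n",
    "console.log(run.answer, `(${run.history.length} messages)`);\n",
    "```\n",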
    "\n",
    "\n",
    "Below we assemble a minimal SQL agent. We will equip it with a set of tools using LangChain's [SqlToolkit](https://api.js.langchain.com/classes/langchain.agents_toolkits_sql.SqlToolkit.html). Using LangGraph's [pre-built ReAct agent constructor](https://langchain-ai.github.io/langgraphjs/how-tos/create-react-agent/), we can do this in one line.\n",
    "\n",
    "The `SqlToolkit` includes tools that can:\n",
    "\n",
    "* Create and execute queries\n",
    "* Check query syntax\n",
    "* Retrieve table descriptions\n",
    "* ... and more"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\n",
      "  {\n",
      "    name: \u001b[32m'query-sql'\u001b[39m,\n",
      "    description: \u001b[32m'Input to this tool is a detailed and correct SQL query, output is a result from the database.\\n'\u001b[39m +\n",
      "      \u001b[32m'  If the query is not correct, an error message will be returned.\\n'\u001b[39m +\n",
      "      \u001b[32m'  If an error is returned, rewrite the query, check the query, and try again.'\u001b[39m\n",
      "  },\n",
      "  {\n",
      "    name: \u001b[32m'info-sql'\u001b[39m,\n",
      "    description: \u001b[32m'Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.\\n'\u001b[39m +\n",
      "      \u001b[32m'    Be sure that the tables actually exist by calling list-tables-sql first!\\n'\u001b[39m +\n",
      "      \u001b[32m'\\n'\u001b[39m +\n",
      "      \u001b[32m'    Example Input: \"table1, table2, table3.'\u001b[39m\n",
      "  },\n",
      "  {\n",
      "    name: \u001b[32m'list-tables-sql'\u001b[39m,\n",
      "    description: \u001b[32m'Input is an empty string, output is a comma-separated list of tables in the database.'\u001b[39m\n",
      "  },\n",
      "  {\n",
      "    name: \u001b[32m'query-checker'\u001b[39m,\n",
      "    description: \u001b[32m'Use this tool to double check if your query is correct before executing it.\\n'\u001b[39m +\n",
      "      \u001b[32m'    Always use this tool before executing a query with query-sql!'\u001b[39m\n",
      "  }\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "import { SqlToolkit } from \"langchain/agents/toolkits/sql\";\n",
    "\n",
    "const toolkit = new SqlToolkit(db, llm);\n",
    "\n",
    "const tools = toolkit.getTools();\n",
    "\n",
    "console.log(\n",
    "  tools.map((tool) => ({\n",
    "    name: tool.name,\n",
    "    description: tool.description,\n",
    "  }))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### System prompt\n",
    "\n",
    "We will also want to load a system prompt for our agent. It consists of instructions for how the agent should behave."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are an agent designed to interact with a SQL database.\n",
      "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
      "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
      "You can order the results by a relevant column to return the most interesting examples in the database.\n",
      "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
      "You have access to tools for interacting with the database.\n",
      "Only use the below tools. Only use the information returned by the below tools to construct your final answer.\n",
      "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
      "\n",
      "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
      "\n",
      "To start you should ALWAYS look at the tables in the database to see what you can query.\n",
      "Do NOT skip this step.\n",
      "Then you should query the schema of the most relevant tables.\n"
     ]
    }
   ],
   "source": [
    "import { pull } from \"langchain/hub\";\n",
    "import { ChatPromptTemplate } from \"@langchain/core/prompts\";\n",
    "\n",
    "const systemPromptTemplate = await pull<ChatPromptTemplate>(\"langchain-ai/sql-agent-system-prompt\");\n",
    "\n",
    "console.log(systemPromptTemplate.promptMessages[0].lc_kwargs.prompt.template)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's populate the parameters highlighted in the prompt:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "const systemMessage = await systemPromptTemplate.format({\n",
    "    dialect: \"SQLite\", top_k: 5\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initializing agent\n",
    "\n",
    "We will use a prebuilt [LangGraph](/docs/concepts/architecture/#langgraph) agent constructor, `createReactAgent`, to build our agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { createReactAgent } from \"@langchain/langgraph/prebuilt\";\n",
    "\n",
    "const agent = createReactAgent({ llm: llm, tools: tools, prompt: systemMessage });"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Consider how the agent responds to the question below:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{=mdx}\n",
    "<details>\n",
    "<summary>Expand for `prettyPrint` code.</summary>\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { BaseMessage, isAIMessage } from \"@langchain/core/messages\";\n",
    "\n",
    "const prettyPrint = (message: BaseMessage) => {\n",
    "  let txt = `[${message._getType()}]: ${message.content}`;\n",
    "  // Append any tool calls requested by an AI message\n",
    "  if (isAIMessage(message) && message.tool_calls?.length) {\n",
    "    const toolCalls = message.tool_calls\n",
    "      ?.map((tc) => `- ${tc.name}(${JSON.stringify(tc.args)})`)\n",
    "      .join(\"\\n\");\n",
    "    txt += ` \\nTools: \\n${toolCalls}`;\n",
    "  }\n",
    "  console.log(txt);\n",
    "};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{=mdx}\n",
    "</details>\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[human]: Which country's customers spent the most?\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- list-tables-sql({\"input\":\"\"})\n",
      "-----\n",
      "\n",
      "[tool]: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- info-sql({\"input\":\"Customer, Invoice, InvoiceLine\"})\n",
      "- info-sql({\"input\":\"Invoice\"})\n",
      "-----\n",
      "\n",
      "[tool]: CREATE TABLE Invoice (\n",
      "InvoiceId INTEGER NOT NULL, CustomerId INTEGER NOT NULL, InvoiceDate DATETIME NOT NULL, BillingAddress NVARCHAR(70) , BillingCity NVARCHAR(40) , BillingState NVARCHAR(40) , BillingCountry NVARCHAR(40) , BillingPostalCode NVARCHAR(10) , Total NUMERIC(10,2) NOT NULL) \n",
      "SELECT * FROM \"Invoice\" LIMIT 3;\n",
      " InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total\n",
      " 1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart null Germany 70174 1.98\n",
      " 2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo null Norway 0171 3.96\n",
      " 3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels null Belgium 1000 5.94\n",
      "\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- query-checker({\"input\":\"SELECT c.Country, SUM(i.Total) AS TotalSpent \\nFROM Customer c \\nJOIN Invoice i ON c.CustomerId = i.CustomerId \\nGROUP BY c.Country \\nORDER BY TotalSpent DESC \\nLIMIT 5;\"})\n",
      "-----\n",
      "\n",
      "[tool]: The SQL query you provided appears to be correct and does not contain any of the common mistakes listed. It properly joins the `Customer` and `Invoice` tables, groups the results by country, and orders the total spending in descending order while limiting the results to the top 5 countries.\n",
      "\n",
      "Here is the original query reproduced:\n",
      "\n",
      "```sql\n",
      "SELECT c.Country, SUM(i.Total) AS TotalSpent \n",
      "FROM Customer c \n",
      "JOIN Invoice i ON c.CustomerId = i.CustomerId \n",
      "GROUP BY c.Country \n",
      "ORDER BY TotalSpent DESC \n",
      "LIMIT 5;\n",
      "``` \n",
      "\n",
      "No changes are necessary.\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- query-sql({\"input\":\"SELECT c.Country, SUM(i.Total) AS TotalSpent \\nFROM Customer c \\nJOIN Invoice i ON c.CustomerId = i.CustomerId \\nGROUP BY c.Country \\nORDER BY TotalSpent DESC \\nLIMIT 5;\"})\n",
      "-----\n",
      "\n",
      "[tool]: [{\"Country\":\"USA\",\"TotalSpent\":523.0600000000003},{\"Country\":\"Canada\",\"TotalSpent\":303.9599999999999},{\"Country\":\"France\",\"TotalSpent\":195.09999999999994},{\"Country\":\"Brazil\",\"TotalSpent\":190.09999999999997},{\"Country\":\"Germany\",\"TotalSpent\":156.48}]\n",
      "-----\n",
      "\n",
      "[ai]: The countries whose customers spent the most are:\n",
      "\n",
      "1. **USA** - $523.06\n",
      "2. **Canada** - $303.96\n",
      "3. **France** - $195.10\n",
      "4. **Brazil** - $190.10\n",
      "5. **Germany** - $156.48\n",
      "-----\n",
      "\n"
     ]
    }
   ],
   "source": [
    "let inputs2 = { messages: [{ role: \"user\", content: \"Which country's customers spent the most?\" }] };\n",
    "\n",
    "for await (\n",
    "  const step of await agent.stream(inputs2, {\n",
    "    streamMode: \"values\",\n",
    "  })\n",
    ") {\n",
    "    const lastMessage = step.messages[step.messages.length - 1];\n",
    "    prettyPrint(lastMessage);\n",
    "    console.log(\"-----\\n\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use the [LangSmith trace](https://smith.langchain.com/public/f4313ba4-a93e-418b-b863-1c2626c330d1/r) to visualize these steps and associated metadata.\n",
    "\n",
    "Note that the agent executes multiple queries until it has the information it needs:\n",
    "1. Lists the available tables;\n",
    "2. Retrieves the schema for three tables;\n",
    "3. Queries multiple tables via a join operation, after validating the query with the `query-checker` tool.\n",
    "\n",
    "The agent is then able to use the result of the final query to generate an answer to the original question.\n",
    "\n",
    "The agent can similarly handle qualitative questions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[human]: Describe the playlisttrack table\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- list-tables-sql({\"input\":\"\"})\n",
      "-----\n",
      "\n",
      "[tool]: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- info-sql({\"input\":\"PlaylistTrack\"})\n",
      "-----\n",
      "\n",
      "[tool]: CREATE TABLE PlaylistTrack (\n",
      "PlaylistId INTEGER NOT NULL, TrackId INTEGER NOT NULL) \n",
      "SELECT * FROM \"PlaylistTrack\" LIMIT 3;\n",
      " PlaylistId TrackId\n",
      " 1 3402\n",
      " 1 3389\n",
      " 1 3390\n",
      "\n",
      "-----\n",
      "\n",
      "[ai]: The `PlaylistTrack` table has the following schema:\n",
      "\n",
      "- **PlaylistId**: INTEGER (NOT NULL)\n",
      "- **TrackId**: INTEGER (NOT NULL)\n",
      "\n",
      "This table is used to associate tracks with playlists. Here are some sample rows from the table:\n",
      "\n",
      "| PlaylistId | TrackId |\n",
      "|------------|---------|\n",
      "| 1          | 3402    |\n",
      "| 1          | 3389    |\n",
      "| 1          | 3390    |\n",
      "-----\n",
      "\n"
     ]
    }
   ],
   "source": [
    "let inputs3 = { messages: [{ role: \"user\", content: \"Describe the playlisttrack table\" }] };\n",
    "\n",
    "for await (\n",
    "  const step of await agent.stream(inputs3, {\n",
    "    streamMode: \"values\",\n",
    "  })\n",
    ") {\n",
    "    const lastMessage = step.messages[step.messages.length - 1];\n",
    "    prettyPrint(lastMessage);\n",
    "    console.log(\"-----\\n\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dealing with high-cardinality columns\n",
    "\n",
    "To filter on columns that contain proper nouns such as addresses, song names, or artists, we first need to double-check the spelling so that the filter matches the values actually stored in the database.\n",
    "\n",
    "We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n",
    "\n",
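    "To make the idea concrete without calling an embeddings API, here is a minimal sketch that ranks known values by how many character trigrams they share with a misspelled query. It swaps in trigram Jaccard similarity for the embedding search used in this guide; the helper names are hypothetical and the candidate list is hand-picked:\n",
    "\n",
    "```typescript\n",
    "// Character trigrams of a lowercased, space-padded string.\n",
    "// Trigram similarity is a stand-in for embedding search (illustration only).\n",
    "function trigrams(text: string): Set<string> {\n",
    "  const padded = `  ${text.toLowerCase()} `;\n",
    "  const grams = new Set<string>();\n",
    "  for (let i = 0; i + 3 <= padded.length; i++) {\n",
    "    grams.add(padded.slice(i, i + 3));\n",
    "  }\n",
    "  return grams;\n",
    "}\n",
    "\n",
    "// Jaccard overlap of two trigram sets: 0 = nothing shared, 1 = identical.\n",
    "function similarity(a: string, b: string): number {\n",
    "  const ga = trigrams(a);\n",
    "  const gb = trigrams(b);\n",
    "  let shared = 0;\n",
    "  ga.forEach((g) => {\n",
    "    if (gb.has(g)) shared++;\n",
    "  });\n",
    "  return shared / (ga.size + gb.size - shared);\n",
    "}\n",
    "\n",
    "// Return the k stored values most similar to the query.\n",
    "function closestValues(query: string, values: string[], k = 3): string[] {\n",
    "  return values\n",
    "    .map((value) => ({ value, score: similarity(query, value) }))\n",
    "    .sort((a, b) => b.score - a.score)\n",
    "    .slice(0, k)\n",
    "    .map((entry) => entry.value);\n",
    "}\n",
    "\n",
    "// Hand-picked candidates; in this guide the full list is queried from the database.\n",
    "const candidates = [\"Alanis Morissette\", \"Alice In Chains\", \"AC/DC\", \"Aerosmith\"];\n",
    "console.log(closestValues(\"alis in chain\", candidates, 1));\n",
    "// expected: [\"Alice In Chains\"]\n",
    "```\n",
    "\n",
    "An embedding model can capture similarity that raw character overlap misses, but the lookup pattern is the same: score every stored value against the user's spelling and hand the best matches back to the agent.\n",
    "\n",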
    "First, we need the unique values for each entity we want. We define a function that runs a query and parses the result into a list of strings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total: 622\n",
      "\n",
      "Sample: AC/DC,Accept,Aerosmith,Alanis Morissette,Alice In Chains...\n"
     ]
    }
   ],
   "source": [
    "// Run a query and return the first column of every row as a cleaned list of strings\n",
    "async function queryAsList(database: SqlDatabase, query: string): Promise<string[]> {\n",
    "  // `database.run` returns the rows as a JSON string; parse it and drop null entries\n",
    "  const res: Array<{ [key: string]: string }> = JSON.parse(\n",
    "    await database.run(query)\n",
    "  )\n",
    "    .flat()\n",
    "    .filter((el: { [key: string]: string } | null) => el != null);\n",
    "  // Keep the first column of each row, stripping standalone numbers and whitespace\n",
    "  const justValues: Array<string> = res.map((item) =>\n",
    "    Object.values(item)[0]\n",
    "      .replace(/\\b\\d+\\b/g, \"\")\n",
    "      .trim()\n",
    "  );\n",
    "  return justValues;\n",
    "}\n",
    "\n",
    "// Gather entities into a list\n",
    "let artists: string[] = await queryAsList(db, \"SELECT Name FROM Artist\");\n",
    "let albums: string[] = await queryAsList(db, \"SELECT Title FROM Album\");\n",
    "let properNouns = artists.concat(albums);\n",
    "\n",
    "console.log(`Total: ${properNouns.length}\\n`);\n",
    "console.log(`Sample: ${properNouns.slice(0, 5)}...`);"
   ]
  },
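  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the cleanup in `queryAsList` does without a live database, here is a minimal sketch that applies the same steps to a hard-coded string. The sample titles are hypothetical, and the string is assumed to have the shape `database.run` produces for a single-column query (a JSON array of row objects, as in the `query-sql` results above):\n",
    "\n",
    "```typescript\n",
    "// Hypothetical query result: a JSON array of row objects keyed by column name.\n",
    "const sampleResult = '[{\"Title\":\"Let There Be Rock\"},{\"Title\":\"Volume 1\"},{\"Title\":\"Song 2\"}]';\n",
    "\n",
    "// Same parsing as `queryAsList`, minus the database call.\n",
    "function rowsToValues(json: string): string[] {\n",
    "  const rows: Array<Record<string, string> | null> = JSON.parse(json).flat();\n",
    "  return rows\n",
    "    .filter((row): row is Record<string, string> => row != null)\n",
    "    // First column of each row, with standalone numbers removed and whitespace trimmed.\n",
    "    .map((row) => Object.values(row)[0].replace(/\\b\\d+\\b/g, \"\").trim());\n",
    "}\n",
    "\n",
    "console.log(rowsToValues(sampleResult));\n",
    "// expected: [\"Let There Be Rock\", \"Volume\", \"Song\"]\n",
    "```\n",
    "\n",
    "Because the regex removes every standalone number, titles that differ only by a number (for example, hypothetical \"Volume 1\" and \"Volume 2\") collapse to the same string. That is likely harmless for spelling lookups but worth knowing if you reuse the function elsewhere."
   ]
  },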
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using this list of proper nouns, we can create a **retriever tool** that the agent can call at its discretion.\n",
    "\n",
    "Let's select an [embeddings model](/docs/integrations/text_embedding/) and [vector store](/docs/integrations/vectorstores/) for this step:\n",
    "\n",
    "```{=mdx}\n",
    "import EmbeddingTabs from \"@theme/EmbeddingTabs\";\n",
    "\n",
    "<EmbeddingTabs/>\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "// @lc-docs-hide-cell\n",
    "import { OpenAIEmbeddings } from \"@langchain/openai\";\n",
    "\n",
    "const embeddings = new OpenAIEmbeddings({model: \"text-embedding-3-large\"});"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{=mdx}\n",
    "import VectorStoreTabs from \"@theme/VectorStoreTabs\";\n",
    "\n",
    "<VectorStoreTabs/>\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "// @lc-docs-hide-cell\n",
    "import { MemoryVectorStore } from \"langchain/vectorstores/memory\";\n",
    "\n",
    "const vectorStore = new MemoryVectorStore(embeddings);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now construct a retrieval tool that can search over relevant proper nouns in the database:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "import { createRetrieverTool } from \"langchain/tools/retriever\";\n",
    "import { Document } from \"@langchain/core/documents\";\n",
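    "\n",
    "// Embed every proper noun and add it to the vector store\n",
    "const documents = properNouns.map(text => new Document({ pageContent: text }));\n",
    "await vectorStore.addDocuments(documents);\n",
    "\n",
    "// Return the 5 closest matches for each lookup\n",
    "const retriever = vectorStore.asRetriever(5);\n",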
    "\n",
    "const retrieverTool = createRetrieverTool(retriever, {\n",
    "  name: \"searchProperNouns\",\n",
    "  description:\n",
    "    \"Use to look up values to filter on. Input is an approximate spelling \" +\n",
    "    \"of the proper noun, output is valid proper nouns. Use the noun most \" +\n",
    "    \"similar to the search.\"\n",
    "});"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try it out:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alice In Chains\n",
      "\n",
      "Alanis Morissette\n",
      "\n",
      "Jagged Little Pill\n",
      "\n",
      "Angel Dust\n",
      "\n",
      "Amy Winehouse\n"
     ]
    }
   ],
   "source": [
    "console.log(await retrieverTool.invoke({ query: \"Alice Chains\" }))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This way, if the agent determines it needs to write a filter based on an artist along the lines of \"Alice Chains\", it can first use the retriever tool to see which values actually exist in that column.\n",
    "\n",
    "Putting this together:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Append instructions about the proper-noun lookup tool to the system message\n",
    "let suffix = (\n",
    "    \"If you need to filter on a proper noun like a Name, you must ALWAYS first look up \" +\n",
    "    \"the filter value using the 'searchProperNouns' tool! Do not try to \" +\n",
    "    \"guess at the proper name - use this function to find similar ones.\"\n",
    ");\n",
    "\n",
    "// Separate the suffix from the existing instructions with a blank line\n",
    "const system = systemMessage + \"\\n\\n\" + suffix;\n",
    "\n",
    "const updatedTools = tools.concat(retrieverTool);\n",
    "\n",
    "const agent2 = createReactAgent({ llm: llm, tools: updatedTools, prompt: system });"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[human]: How many albums does alis in chain have?\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- searchProperNouns({\"query\":\"alis in chain\"})\n",
      "-----\n",
      "\n",
      "[tool]: Alice In Chains\n",
      "\n",
      "Alanis Morissette\n",
      "\n",
      "Up An' Atom\n",
      "\n",
      "A-Sides\n",
      "\n",
      "Jagged Little Pill\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- query-sql({\"input\":\"SELECT COUNT(*) FROM albums WHERE artist_name = 'Alice In Chains'\"})\n",
      "-----\n",
      "\n",
      "[tool]: QueryFailedError: SQLITE_ERROR: no such table: albums\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- list-tables-sql({\"input\":\"\"})\n",
      "-----\n",
      "\n",
      "[tool]: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- info-sql({\"input\":\"Album\"})\n",
      "- info-sql({\"input\":\"Artist\"})\n",
      "-----\n",
      "\n",
      "[tool]: CREATE TABLE Artist (\n",
      "ArtistId INTEGER NOT NULL, Name NVARCHAR(120) ) \n",
      "SELECT * FROM \"Artist\" LIMIT 3;\n",
      " ArtistId Name\n",
      " 1 AC/DC\n",
      " 2 Accept\n",
      " 3 Aerosmith\n",
      "\n",
      "-----\n",
      "\n",
      "[ai]:  \n",
      "Tools: \n",
      "- query-sql({\"input\":\"SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\"})\n",
      "-----\n",
      "\n",
      "[tool]: [{\"COUNT(*)\":1}]\n",
      "-----\n",
      "\n",
      "[ai]: Alice In Chains has released 1 album.\n",
      "-----\n",
      "\n"
     ]
    }
   ],
   "source": [
    "let inputs4 = { messages: [{ role: \"user\", content: \"How many albums does alis in chain have?\" }] };\n",
    "\n",
    "for await (\n",
    "  const step of await agent2.stream(inputs4, {\n",
    "    streamMode: \"values\",\n",
    "  })\n",
    ") {\n",
    "    const lastMessage = step.messages[step.messages.length - 1];\n",
    "    prettyPrint(lastMessage);\n",
    "    console.log(\"-----\\n\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
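    "As we can see, both in the streamed steps and in the [LangSmith trace](https://smith.langchain.com/public/8b14a4a4-c08b-4b85-8086-c050931ae03d/r), the agent used the `searchProperNouns` tool to find the correct spelling of the artist's name before filtering on it. It also recovered from its first guess of a nonexistent `albums` table by listing the available tables and checking the `Album` and `Artist` schemas."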
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TypeScript",
   "language": "typescript",
   "name": "tslab"
  },
  "language_info": {
   "codemirror_mode": {
    "mode": "typescript",
    "name": "javascript",
    "typescript": true
   },
   "file_extension": ".ts",
   "mimetype": "text/typescript",
   "name": "typescript",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
